// Copyright (c) 2020-present, INSPUR Co, Ltd. All rights reserved.
// This source code is licensed under Apache 2.0 License.
//
// Created by florian on 05.08.15.

#include "tree_node.h"
#include "pure_mem/epoche.h"
#include <algorithm>
#include <assert.h>
#include <emmintrin.h> // x86 SSE intrinsics

namespace syn_art_fullkey {

// Forwards a child replacement for key byte `key` to the concrete node type
// of `node`. insertGrow relies on this to repoint the parent's slot at a
// freshly grown node. An unrecognised node type is a programming error.
void N::change(N *node, uint8_t key, N *val) {
  const auto kind = node->getType();
  if (kind == NTypes::NodeType4) {
    static_cast<N4 *>(node)->change(key, val);
  } else if (kind == NTypes::NodeType16) {
    static_cast<N16 *>(node)->change(key, val);
  } else if (kind == NTypes::NodeType48) {
    static_cast<N48 *>(node)->change(key, val);
  } else if (kind == NTypes::NodeType256) {
    static_cast<N256 *>(node)->change(key, val);
  } else {
    assert(false);
    __builtin_unreachable();
  }
}

// Inserts child `val` under key byte `key` into `n`, growing `n` into the next
// larger node type when the in-place insert is rejected.
//
// Fast path: if N::insert accepts the child directly, the function returns.
// Grow path:
//   1. allocate nBig of type (n's type + 1) with n's level and prefix,
//   2. copy n's children and full-key leaf into nBig and insert `val`,
//   3. lock nBig, then write-lock parentNode; if the parent lock requests a
//      restart, nBig is freed and the function returns with needRestart set
//      (`n` is neither reassigned nor unlocked in that case),
//   4. otherwise repoint parentNode's slot `keyParent` at nBig, unlock the
//      parent, make `n` refer to nBig, and retire the old node: it is
//      unlocked as obsolete and handed to DeleteWhileNoRefs for deferred
//      deletion.
//
// NOTE(review): inferred lock contract, confirm with callers. `n` appears to
// be write-locked on entry (the old node is released via writeUnlockObsolete).
// On the grow path `n` comes back pointing at nBig, which is still locked.
// needRestart is never cleared here, so callers presumably pass it in as
// false (otherwise the assert below would fire).
void N::insertGrow(N *&n, N *parentNode, uint8_t keyParent, uint8_t key, N *val,
                   bool &needRestart) {
  if (N::insert(n, key, val)) {
    return;
  }
  // Carry n's prefix over to the replacement node.
  rocksdb::Slice pref;
  pref.data_ = N::getPrefix(n).prefix_;
  pref.size_ = N::getPrefix(n).length_;
  // "Next larger" type is computed arithmetically; this assumes the NTypes
  // values are consecutive (NodeType4 -> 16 -> 48 -> 256).
  auto nBig = N::newTreeNode((NTypes)((uint8_t)(n->getType()) + 1),
                             n->getLevel(), pref);
  N::copyTo(n, nBig);
  nBig->setFullKeyLeaf(nullptr, n->getFullKeyLeaf());
  N::insert(nBig, key, val);
  // nBig is not reachable by any other thread yet (it is only published by
  // N::change below), so locking it is expected never to request a restart.
  nBig->lockVersionOrRestart(needRestart);
  assert(!needRestart);
  parentNode->writeLockOrRestart(needRestart);
  if (needRestart) {
    // nBig merely borrows n's child pointers, so a plain free() is correct
    // here (removeAll would also free the shared children).
    free(nBig);
    return;
  }

  // Publish nBig in the parent, then retire the old node.
  N::change(parentNode, keyParent, nBig);
  parentNode->writeUnlock();
  N *oldNode = n;
  n = nBig;
  oldNode->writeUnlockObsolete();
  rocksdb::DeleteWhileNoRefs::getInstance()->markNodeForDeletion(oldNode);
}

// Copies every child of `srcNode` into `destNode`. insertGrow uses this to
// fill a freshly allocated larger node. Each concrete copyTo re-inserts the
// children through N::insert, so destNode's own layout is rebuilt rather
// than byte-copied.
void N::copyTo(N *srcNode, N *destNode) {
  switch (srcNode->getType()) {
  case NTypes::NodeType4:
    static_cast<N4 *>(srcNode)->copyTo(destNode);
    return;
  case NTypes::NodeType16:
    static_cast<N16 *>(srcNode)->copyTo(destNode);
    return;
  case NTypes::NodeType48:
    static_cast<N48 *>(srcNode)->copyTo(destNode);
    return;
  case NTypes::NodeType256:
    static_cast<N256 *>(srcNode)->copyTo(destNode);
    return;
  }
  // Unknown node type: flag it in debug builds, like the other dispatchers
  // (change, getChild, insert, getPrefix). __builtin_unreachable() is
  // deliberately omitted, so release builds keep the previous behaviour of
  // copying nothing instead of invoking undefined behaviour.
  assert(false);
}

// Re-inserts the first count_ slots of this N4 into `n`, skipping slots whose
// child pointer is nullptr.
void N4::copyTo(N *n) const {
  for (uint32_t slot = 0; slot < count_; ++slot) {
    N *const child = children_[slot].load();
    if (child == nullptr)
      continue;
    N::insert(n, keys_[slot].load(), child);
  }
}

// Re-inserts the first count_ slots of this N16 into `n`, skipping slots whose
// child pointer is nullptr. Stored keys are passed through flipSign to get
// the key byte handed to N::insert (presumably keys_ holds sign-flipped
// bytes for the SSE signed comparison; confirm in the header).
void N16::copyTo(N *n) const {
  for (unsigned slot = 0; slot < count_; ++slot) {
    N *const child = children_[slot].load();
    if (child == nullptr)
      continue;
    N::insert(n, flipSign(keys_[slot]), child);
  }
}

// Re-inserts every occupied key byte of this N48 into `n`. childIndex_ maps
// each of the 256 key bytes to a slot in children_. A value of emptyMarker
// means that key byte has no child.
void N48::copyTo(N *n) const {
  for (unsigned keyByte = 0; keyByte < 256; ++keyByte) {
    const uint8_t slot = childIndex_[keyByte].load();
    if (slot == emptyMarker)
      continue;
    N::insert(n, keyByte, children_[slot]);
  }
}

// Re-inserts every non-null child of this N256 into `n`. The array index
// itself is the key byte.
void N256::copyTo(N *n) const {
  for (unsigned keyByte = 0; keyByte < 256; ++keyByte) {
    N *const child = children_[keyByte].load();
    if (child == nullptr)
      continue;
    N::insert(n, keyByte, child);
  }
}

// Looks up the child stored under key byte `k` by delegating to the concrete
// node type of `node`. The result for a missing key is whatever the per-type
// getChild returns (presumably nullptr; confirm in the header).
N *N::getChild(const uint8_t k, N *node) {
  const auto kind = node->getType();
  if (kind == NTypes::NodeType4)
    return static_cast<N4 *>(node)->getChild(k);
  if (kind == NTypes::NodeType16)
    return static_cast<N16 *>(node)->getChild(k);
  if (kind == NTypes::NodeType48)
    return static_cast<N48 *>(node)->getChild(k);
  if (kind == NTypes::NodeType256)
    return static_cast<N256 *>(node)->getChild(k);
  assert(false);
  __builtin_unreachable();
}

// Collects (key byte, child) pairs of `node` into `children1` and reports how
// many were written through `childrenCount`. The work is delegated to the
// concrete node type's getChildrenSmall, which defines how `start`, `end`
// and `childMax` are interpreted. Presumably `start`/`end` bound the key-byte
// range inclusively and `childMax` caps the output; removeAll passes 0, 255
// and 256 to fetch every child.
void N::getChildrenSmall(const N *node, uint8_t start, uint8_t end,
                         std::tuple<uint8_t, N *> children1[],
                         uint32_t &childrenCount, u_int32_t childMax) {
  switch (node->getType()) {
  case NTypes::NodeType4:
    static_cast<const N4 *>(node)->getChildrenSmall(start, end, children1,
                                                    childrenCount, childMax);
    return;
  case NTypes::NodeType16:
    static_cast<const N16 *>(node)->getChildrenSmall(start, end, children1,
                                                     childrenCount, childMax);
    return;
  case NTypes::NodeType48:
    static_cast<const N48 *>(node)->getChildrenSmall(start, end, children1,
                                                     childrenCount, childMax);
    return;
  case NTypes::NodeType256:
    static_cast<const N256 *>(node)->getChildrenSmall(start, end, children1,
                                                      childrenCount, childMax);
    return;
  }
  // Unknown node type: none of the outputs (including childrenCount) have
  // been written. Flag it in debug builds, like the other dispatchers.
  // __builtin_unreachable() is deliberately omitted so release builds keep
  // the previous behaviour.
  assert(false);
}

// Collects (key byte, child) pairs of `node` into `children1` and reports how
// many were written through `childrenCount`. The work is delegated to the
// concrete node type's getChildrenLarge, which defines how `start`, `end`
// and `childMax` are interpreted. The difference from getChildrenSmall is
// not visible in this file (presumably the traversal direction; confirm in
// the header).
void N::getChildrenLarge(const N *node, uint8_t start, uint8_t end,
                         std::tuple<uint8_t, N *> children1[],
                         uint32_t &childrenCount, u_int32_t childMax) {
  switch (node->getType()) {
  case NTypes::NodeType4:
    static_cast<const N4 *>(node)->getChildrenLarge(start, end, children1,
                                                    childrenCount, childMax);
    return;
  case NTypes::NodeType16:
    static_cast<const N16 *>(node)->getChildrenLarge(start, end, children1,
                                                     childrenCount, childMax);
    return;
  case NTypes::NodeType48:
    static_cast<const N48 *>(node)->getChildrenLarge(start, end, children1,
                                                     childrenCount, childMax);
    return;
  case NTypes::NodeType256:
    static_cast<const N256 *>(node)->getChildrenLarge(start, end, children1,
                                                      childrenCount, childMax);
    return;
  }
  // Unknown node type: none of the outputs (including childrenCount) have
  // been written. Flag it in debug builds, like the other dispatchers.
  // __builtin_unreachable() is deliberately omitted so release builds keep
  // the previous behaviour.
  assert(false);
}

// Returns a newly malloc'd copy of `node` whose prefix is replaced by
// `prefix`. The fixed-size part of the concrete node is copied bytewise.
// The allocation is extended by prefix.size() bytes, and the new prefix
// length and bytes are written into the copy. `node` itself is not modified
// or freed. The copy must be released with free().
// NOTE(review): the malloc result is not checked for nullptr.
N *N::copyNodeWithNewPrefix(N *node, const rocksdb::Slice &prefix) {
  // Size of the concrete node struct, excluding the trailing prefix bytes.
  size_t nodeBytes = 0;
  switch (node->getType()) {
  case NTypes::NodeType4:
    nodeBytes = sizeof(N4);
    break;
  case NTypes::NodeType16:
    nodeBytes = sizeof(N16);
    break;
  case NTypes::NodeType48:
    nodeBytes = sizeof(N48);
    break;
  case NTypes::NodeType256:
    nodeBytes = sizeof(N256);
    break;
  default:
    assert(false);
    __builtin_unreachable();
  }

  N *copy = static_cast<N *>(malloc(nodeBytes + prefix.size()));
  memcpy(copy, node, nodeBytes);
  // The copied header carries the same node type, so this resolves to the
  // concrete type's prefix storage.
  auto &newPrefix = N::getPrefix(copy);
  newPrefix.length_ = prefix.size();
  memcpy(newPrefix.prefix_, prefix.data_, prefix.size());
  return copy;
}

// Reallocates `node` so it has room for `prefix`, then stores the new prefix
// length and bytes in it. Returns the possibly moved node. After the call,
// the old `node` pointer must not be used, because realloc may have released
// it.
// NOTE(review): the realloc result is not checked for nullptr. Also, if other
// threads can still reach `node`, a moving realloc would leave them holding a
// dangling pointer; confirm that callers guarantee exclusive access.
N *N::nodePrefixChange(N *node, const rocksdb::Slice &prefix) {
  // Size of the concrete node struct, excluding the trailing prefix bytes.
  size_t nodeBytes = 0;
  switch (node->getType()) {
  case NTypes::NodeType4:
    nodeBytes = sizeof(N4);
    break;
  case NTypes::NodeType16:
    nodeBytes = sizeof(N16);
    break;
  case NTypes::NodeType48:
    nodeBytes = sizeof(N48);
    break;
  case NTypes::NodeType256:
    nodeBytes = sizeof(N256);
    break;
  default:
    assert(false);
    __builtin_unreachable();
  }

  N *resized = static_cast<N *>(realloc(node, nodeBytes + prefix.size()));
  // realloc preserves the node header, so the type still selects the right
  // concrete prefix storage.
  auto &newPrefix = N::getPrefix(resized);
  newPrefix.length_ = prefix.size();
  memcpy(newPrefix.prefix_, prefix.data_, prefix.size());
  return resized;
}

// Returns a reference to the prefix held by `node`, dispatching on its
// concrete node type. The reference presumably points into the node's own
// allocation (copyNodeWithNewPrefix writes the prefix bytes there), so it
// should not be used once the node is freed or reallocated.
Prefix &N::getPrefix(N *node) {
  const auto kind = node->getType();
  if (kind == NTypes::NodeType4)
    return static_cast<N4 *>(node)->getPrefix();
  if (kind == NTypes::NodeType16)
    return static_cast<N16 *>(node)->getPrefix();
  if (kind == NTypes::NodeType48)
    return static_cast<N48 *>(node)->getPrefix();
  if (kind == NTypes::NodeType256)
    return static_cast<N256 *>(node)->getPrefix();
  assert(false);
  __builtin_unreachable();
}

// Inserts child `val` under key byte `key` into `node`, dispatching on its
// concrete node type. Returns the per-type result. insertGrow treats `false`
// as "this node cannot take the child" and grows into the next larger type.
bool N::insert(N *node, uint8_t key, N *val) {
  const auto kind = node->getType();
  if (kind == NTypes::NodeType4)
    return static_cast<N4 *>(node)->insert(key, val);
  if (kind == NTypes::NodeType16)
    return static_cast<N16 *>(node)->insert(key, val);
  if (kind == NTypes::NodeType48)
    return static_cast<N48 *>(node)->insert(key, val);
  if (kind == NTypes::NodeType256)
    return static_cast<N256 *>(node)->insert(key, val);
  assert(false);
  __builtin_unreachable();
}

// Allocates a node of the requested `type` with malloc and initialises it
// via init(level, prefix). The allocation is sizeof(concrete node) plus
// prefix.size() extra bytes (presumably the inline prefix storage). Nodes
// built here are released with free() elsewhere in this file. An unknown
// `type` asserts in debug builds and yields nullptr otherwise.
// NOTE(review): the malloc result is not checked for nullptr.
N *N::newTreeNode(NTypes type, uint16_t level, const rocksdb::Slice &prefix) {
  const size_t prefixBytes = prefix.size();
  if (type == NTypes::NodeType4) {
    N4 *fresh = static_cast<N4 *>(malloc(sizeof(N4) + prefixBytes));
    fresh->init(level, prefix);
    return fresh;
  }
  if (type == NTypes::NodeType16) {
    N16 *fresh = static_cast<N16 *>(malloc(sizeof(N16) + prefixBytes));
    fresh->init(level, prefix);
    return fresh;
  }
  if (type == NTypes::NodeType48) {
    N48 *fresh = static_cast<N48 *>(malloc(sizeof(N48) + prefixBytes));
    fresh->init(level, prefix);
    return fresh;
  }
  if (type == NTypes::NodeType256) {
    N256 *fresh = static_cast<N256 *>(malloc(sizeof(N256) + prefixBytes));
    fresh->init(level, prefix);
    return fresh;
  }
  assert(false);
  return nullptr;
}

// Recursively frees `node` and every inner node reachable below it.
//
// nullptr and leaves (N::isLeaf) are left untouched; only inner nodes are
// released, each with free(), matching the malloc/realloc allocations in
// newTreeNode, copyNodeWithNewPrefix and nodePrefixChange.
// No locks are taken, so this is presumably only safe once no other thread
// can reach the subtree.
// NOTE(review): the full-key leaf attached via setFullKeyLeaf is not visited
// here; confirm it is reclaimed elsewhere.
// Each recursion level keeps a 256-entry child buffer on the stack.
void N::removeAll(N *node) {
  if (node == nullptr || N::isLeaf(node))
    return;
  std::tuple<uint8_t, N *> children[256];
  // Start from zero: the getChildrenSmall dispatcher writes nothing for an
  // unrecognised node type, and the loop below must not read an
  // indeterminate count in that case.
  uint32_t childrenCount = 0;
  N::getChildrenSmall(node, 0, 255, children, childrenCount, 256);
  for (uint32_t i = 0; i < childrenCount; ++i) {
    N::removeAll(std::get<1>(children[i]));
  }
  free(node);
}

} // namespace syn_art_fullkey
